"""Provides device automations for MQTT."""
import logging
from typing import Callable, List, Optional

import attr
import voluptuous as vol

from homeassistant.components import mqtt
from homeassistant.components.automation import AutomationActionType
import homeassistant.components.automation.mqtt as automation_mqtt
from homeassistant.components.device_automation import TRIGGER_BASE_SCHEMA
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM, CONF_TYPE
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.typing import ConfigType, HomeAssistantType

from . import (
    ATTR_DISCOVERY_HASH,
    ATTR_DISCOVERY_TOPIC,
    CONF_CONNECTIONS,
    CONF_DEVICE,
    CONF_IDENTIFIERS,
    CONF_PAYLOAD,
    CONF_QOS,
    DOMAIN,
    cleanup_device_registry,
    debug_info,
)
from .discovery import MQTT_DISCOVERY_UPDATED, clear_discovery_hash

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Configuration keys used by the schemas in this module.
CONF_AUTOMATION_TYPE = "automation_type"
CONF_DISCOVERY_ID = "discovery_id"
CONF_SUBTYPE = "subtype"
CONF_TOPIC = "topic"
# Encoding passed to the MQTT automation trigger when attaching.
DEFAULT_ENCODING = "utf-8"
# Value the schemas and trigger descriptions use for the platform key.
DEVICE = "device"

# Keys shared by every trigger description built in async_get_triggers.
MQTT_TRIGGER_BASE = {
    # Trigger when MQTT message is received
    CONF_PLATFORM: DEVICE,
    CONF_DOMAIN: DOMAIN,
}

# Schema for a device trigger configuration; async_attach_trigger validates
# its config against it. The discovery id selects the Trigger in hass.data.
TRIGGER_SCHEMA = TRIGGER_BASE_SCHEMA.extend(
    {
        vol.Required(CONF_PLATFORM): DEVICE,
        vol.Required(CONF_DOMAIN): DOMAIN,
        vol.Required(CONF_DEVICE_ID): str,
        vol.Required(CONF_DISCOVERY_ID): str,
        vol.Required(CONF_TYPE): cv.string,
        vol.Required(CONF_SUBTYPE): cv.string,
    }
)

# Schema for a discovered device trigger config; used by async_setup_trigger
# for the initial config and for every non-empty discovery update.
TRIGGER_DISCOVERY_SCHEMA = mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend(
    {
        vol.Required(CONF_AUTOMATION_TYPE): str,
        vol.Required(CONF_DEVICE): mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA,
        vol.Required(CONF_TOPIC): mqtt.valid_subscribe_topic,
        # The payload is optional and defaults to None
        vol.Optional(CONF_PAYLOAD, default=None): vol.Any(None, cv.string),
        vol.Required(CONF_TYPE): cv.string,
        vol.Required(CONF_SUBTYPE): cv.string,
    },
    mqtt.validate_device_has_at_least_one_identifier,
)

# Key in hass.data holding a dict of Trigger objects indexed by discovery id.
DEVICE_TRIGGERS = "mqtt_device_triggers"


@attr.s(slots=True)
class TriggerInstance:
    """One automation action attached to a device trigger."""

    action: AutomationActionType = attr.ib()
    automation_info: dict = attr.ib()
    trigger: "Trigger" = attr.ib()
    # Callback returned by the MQTT automation trigger; None while not attached.
    remove: Optional[CALLBACK_TYPE] = attr.ib(default=None)

    async def async_attach_trigger(self):
        """Attach the action to the parent trigger's MQTT topic.

        A previous attachment, if any, is removed before the new one is made.
        """
        parent = self.trigger
        mqtt_config = {
            automation_mqtt.CONF_TOPIC: parent.topic,
            automation_mqtt.CONF_ENCODING: DEFAULT_ENCODING,
            automation_mqtt.CONF_QOS: parent.qos,
        }
        # Only pass a payload on when one is configured
        payload = parent.payload
        if payload:
            mqtt_config[CONF_PAYLOAD] = payload

        previous_remove = self.remove
        if previous_remove:
            previous_remove()
        self.remove = await automation_mqtt.async_attach_trigger(
            parent.hass, mqtt_config, self.action, self.automation_info
        )


@attr.s(slots=True)
class Trigger:
    """State of one MQTT device trigger and the automations attached to it.

    A topic of None marks the trigger as unknown: actions can still be added,
    but they are only attached to MQTT once a topic is set.
    """

    device_id: str = attr.ib()
    discovery_data: dict = attr.ib()
    hass: HomeAssistantType = attr.ib()
    payload: str = attr.ib()
    qos: int = attr.ib()
    remove_signal: Callable[[], None] = attr.ib()
    subtype: str = attr.ib()
    topic: str = attr.ib()
    type: str = attr.ib()
    trigger_instances: List[TriggerInstance] = attr.ib(factory=list)

    async def add_trigger(self, action, automation_info):
        """Register an action and return a callback that removes it again."""
        attached = TriggerInstance(action, automation_info, self)
        self.trigger_instances.append(attached)

        # Attach to MQTT right away only when the trigger is known
        if self.topic is not None:
            await attached.async_attach_trigger()

        @callback
        def async_remove() -> None:
            """Detach the action and forget it; raise if already removed."""
            if attached not in self.trigger_instances:
                raise HomeAssistantError("Can't remove trigger twice")

            detach = attached.remove
            if detach:
                detach()
            self.trigger_instances.remove(attached)

        return async_remove

    async def update_trigger(self, config, discovery_hash, remove_signal):
        """Apply an updated discovery config; reattach if the topic changed."""
        self.remove_signal = remove_signal
        self.type = config[CONF_TYPE]
        self.subtype = config[CONF_SUBTYPE]
        self.payload = config[CONF_PAYLOAD]
        self.qos = config[CONF_QOS]
        previous_topic = self.topic
        self.topic = config[CONF_TOPIC]

        # Reattaching with an unchanged topic would run unsubscribe+subscribe
        # in the wrong order, because unsubscribe is done with help of
        # async_create_task; so only reattach when the topic really changed.
        if previous_topic == self.topic:
            return

        for attached in self.trigger_instances:
            await attached.async_attach_trigger()

    def detach_trigger(self):
        """Mark the trigger as unknown and detach every attached action."""
        # A topic of None marks the trigger as unknown
        self.topic = None

        for attached in self.trigger_instances:
            if not attached.remove:
                continue
            attached.remove()
            attached.remove = None


async def _update_device(hass, config_entry, config):
    """Create or update the device registry entry described by the config."""
    registry = await hass.helpers.device_registry.async_get_registry()
    entry_id = config_entry.entry_id
    device_info = mqtt.device_info_from_config(config[CONF_DEVICE])

    # Nothing to register without both a config entry id and device info
    if entry_id is None or device_info is None:
        return

    device_info["config_entry_id"] = entry_id
    registry.async_get_or_create(**device_info)


async def async_setup_trigger(hass, config, config_entry, discovery_data):
    """Set up the MQTT device trigger.

    Validates the discovered config, connects a handler for later discovery
    updates of this trigger, updates the device registry and then either
    creates the Trigger or updates an already existing one (for example one
    created by async_attach_trigger before the trigger was discovered).

    Nothing is stored if the device cannot be found in the device registry.
    """
    config = TRIGGER_DISCOVERY_SCHEMA(config)
    discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
    discovery_id = discovery_hash[1]
    remove_signal = None

    async def discovery_update(payload):
        """Handle discovery update."""
        # remove_signal and device are closure variables which the enclosing
        # function assigns after this handler has been defined.
        _LOGGER.info(
            "Got update for trigger with hash: %s '%s'", discovery_hash, payload
        )
        if not payload:
            # Empty payload: Remove trigger
            _LOGGER.info("Removing trigger: %s", discovery_hash)
            debug_info.remove_trigger_discovery_data(hass, discovery_hash)
            if discovery_id in hass.data[DEVICE_TRIGGERS]:
                device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id]
                device_trigger.detach_trigger()
                clear_discovery_hash(hass, discovery_hash)
                remove_signal()
                await cleanup_device_registry(hass, device.id)
        else:
            # Non-empty payload: Update trigger
            _LOGGER.info("Updating trigger: %s", discovery_hash)
            debug_info.update_trigger_discovery_data(hass, discovery_hash, payload)
            config = TRIGGER_DISCOVERY_SCHEMA(payload)
            await _update_device(hass, config_entry, config)
            device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id]
            await device_trigger.update_trigger(config, discovery_hash, remove_signal)

    remove_signal = async_dispatcher_connect(
        hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_update
    )

    await _update_device(hass, config_entry, config)

    device_registry = await hass.helpers.device_registry.async_get_registry()
    device = device_registry.async_get_device(
        {(DOMAIN, id_) for id_ in config[CONF_DEVICE][CONF_IDENTIFIERS]},
        {tuple(x) for x in config[CONF_DEVICE][CONF_CONNECTIONS]},
    )

    if device is None:
        return

    if DEVICE_TRIGGERS not in hass.data:
        hass.data[DEVICE_TRIGGERS] = {}
    if discovery_id not in hass.data[DEVICE_TRIGGERS]:
        hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
            hass=hass,
            device_id=device.id,
            discovery_data=discovery_data,
            type=config[CONF_TYPE],
            subtype=config[CONF_SUBTYPE],
            topic=config[CONF_TOPIC],
            payload=config[CONF_PAYLOAD],
            qos=config[CONF_QOS],
            remove_signal=remove_signal,
        )
    else:
        existing_trigger = hass.data[DEVICE_TRIGGERS][discovery_id]
        # A trigger created by async_attach_trigger before discovery carries
        # discovery_data=None and update_trigger does not set it. Store it
        # here, otherwise async_device_removed fails when it reads the
        # discovery hash and topic from the trigger.
        existing_trigger.discovery_data = discovery_data
        await existing_trigger.update_trigger(config, discovery_hash, remove_signal)
    debug_info.add_trigger_discovery_data(
        hass, discovery_hash, discovery_data, device.id
    )


async def async_device_removed(hass: HomeAssistant, device_id: str):
    """Clean up every known trigger of a device that has been removed."""
    for description in await async_get_triggers(hass, device_id):
        discovery_id = description[CONF_DISCOVERY_ID]
        device_trigger = hass.data[DEVICE_TRIGGERS].pop(discovery_id)
        if not device_trigger:
            continue

        discovery_data = device_trigger.discovery_data
        discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
        discovery_topic = discovery_data[ATTR_DISCOVERY_TOPIC]

        debug_info.remove_trigger_discovery_data(hass, discovery_hash)
        device_trigger.detach_trigger()
        clear_discovery_hash(hass, discovery_hash)
        device_trigger.remove_signal()
        # Publish an empty retained payload to the trigger's discovery topic
        mqtt.publish(hass, discovery_topic, "", retain=True)


async def async_get_triggers(hass: HomeAssistant, device_id: str) -> List[dict]:
    """Return descriptions of the known MQTT triggers of a device.

    Triggers of other devices and triggers without a topic are left out.
    """
    if DEVICE_TRIGGERS not in hass.data:
        return []

    return [
        {
            **MQTT_TRIGGER_BASE,
            "device_id": device_id,
            "type": trig.type,
            "subtype": trig.subtype,
            "discovery_id": discovery_id,
        }
        for discovery_id, trig in hass.data[DEVICE_TRIGGERS].items()
        if trig.device_id == device_id and trig.topic is not None
    ]


async def async_attach_trigger(
    hass: HomeAssistant,
    config: ConfigType,
    action: AutomationActionType,
    automation_info: dict,
) -> CALLBACK_TYPE:
    """Attach an automation action to the device trigger named in config.

    If no trigger with the configured discovery id is known yet, one is
    created without topic, payload, qos or discovery data.
    """
    # Make sure the trigger store exists before the config is validated
    device_triggers = hass.data.setdefault(DEVICE_TRIGGERS, {})
    validated = TRIGGER_SCHEMA(config)
    discovery_id = validated[CONF_DISCOVERY_ID]

    if discovery_id not in device_triggers:
        device_triggers[discovery_id] = Trigger(
            hass=hass,
            device_id=validated[CONF_DEVICE_ID],
            discovery_data=None,
            remove_signal=None,
            type=validated[CONF_TYPE],
            subtype=validated[CONF_SUBTYPE],
            topic=None,
            payload=None,
            qos=None,
        )

    device_trigger = device_triggers[discovery_id]
    return await device_trigger.add_trigger(action, automation_info)
